import type { Base } from "../Base";
import { mat4, vec3 } from "wgpu-matrix";
import normalMapWGSL from "./normalMap.wgsl";
import { createMeshRenderable } from "../../meshes/mesh";
import { createBoxMeshWithTangents } from "../../meshes/box";
import {
  createBindGroupDescriptor,
  create3DRenderPipeline,
  createTextureFromImage,
} from "./utils";
import { quitIfWebGPUNotAvailable } from "../../utils";

// Byte size of one 4x4 matrix of 32-bit floats (16 elements * 4 bytes = 64).
const MAT4X4_BYTES = 16 * Float32Array.BYTES_PER_ELEMENT;
/** Visualization / shading technique selectable from the GUI. */
type BumpMode =
  | "Albedo Texture"
  | "Normal Texture"
  | "Depth Texture"
  | "Normal Map"
  | "Parallax Scale"
  | "Steep Parallax";

/** Values exposed to the GUI and read back by the sample every frame. */
interface GUISettings {
  /** Technique to render with; translated to a numeric mode for the shader. */
  "Bump Mode": BumpMode;
  /** Camera (eye) position used to build the view matrix; the camera looks at the origin. */
  cameraPosX: number;
  cameraPosY: number;
  cameraPosZ: number;
  /** Light position; it is transformed by the view matrix before being uploaded. */
  lightPosX: number;
  lightPosY: number;
  lightPosZ: number;
  /** Uploaded to the shader's map-info uniform as a 32-bit float. */
  lightIntensity: number;
  /** Uploaded to the shader's map-info uniform as a 32-bit float. */
  depthScale: number;
  /** Uploaded to the shader's map-info uniform as a 32-bit float. */
  depthLayers: number;
  /** Name of the surface texture set ("Spiral", "Toybox" or "BrickWall"). */
  Texture: string;
  /** GUI callback; presumably wired up by the hosting UI — confirm against the caller. */
  "Reset Light": () => void;
}
/**
 * Sample that draws a rotating box and shades it with one of several
 * bump-mapping techniques chosen through `settings["Bump Mode"]`.
 *
 * `run()` creates every GPU resource and starts a `requestAnimationFrame`
 * loop that re-reads `this.settings` each frame; `onUpdate()` swaps the
 * settings object and re-applies the texture selection.
 */
export default class NormalMap implements Base {
  name = "NormalMap";
  description?: string | undefined =
    "This example demonstrates multiple different methods that employ fragment shaders to achieve additional perceptual depth on the surface of a cube mesh.\
     Demonstrated methods include normal mapping, parallax mapping, and steep parallax mapping.";
  settings: GUISettings = {
    "Bump Mode": "Normal Map",
    cameraPosX: 0.0,
    cameraPosY: 0.8,
    cameraPosZ: -1.4,
    lightPosX: 1.7,
    lightPosY: 0.7,
    lightPosZ: -1.9,
    lightIntensity: 5.0,
    depthScale: 0.05,
    depthLayers: 16,
    Texture: "Spiral",
    "Reset Light": () => {
      return;
    },
  };

  /**
   * Installed by `run()` once GPU resources exist. Re-reads
   * `settings.Texture` and selects the matching surface bind group.
   * Undefined until `run()` has finished its setup.
   */
  onChangeTexture?: () => void;

  /**
   * Sets up WebGPU on `canvas`, loads the textures, builds the pipeline and
   * starts the render loop. The loop runs until the page goes away.
   *
   * @param canvas Canvas to render into; its backing size is set from its
   *   client size and the device pixel ratio once, at startup.
   * @throws Error when one of the texture images cannot be fetched.
   */
  async run(canvas: HTMLCanvasElement): Promise<void> {
    const adapter = await navigator.gpu?.requestAdapter({
      featureLevel: "compatibility",
    });
    const requestedDevice = await adapter?.requestDevice();
    quitIfWebGPUNotAvailable(adapter, requestedDevice ?? null);
    if (!requestedDevice) {
      // The helper above is responsible for reporting the problem; without a
      // device nothing below can work, so stop instead of crashing later.
      return;
    }
    const device = requestedDevice;

    const context = canvas.getContext("webgpu") as GPUCanvasContext;
    const devicePixelRatio = window.devicePixelRatio;
    canvas.width = canvas.clientWidth * devicePixelRatio;
    canvas.height = canvas.clientHeight * devicePixelRatio;
    const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
    context.configure({
      device,
      format: presentationFormat,
    });

    // Depth buffer; sized once from the canvas size computed above.
    const depthTexture = device.createTexture({
      size: [canvas.width, canvas.height],
      format: "depth24plus",
      usage: GPUTextureUsage.RENDER_ATTACHMENT,
    });
    const depthView = depthTexture.createView();

    const spaceTransformsBuffer = device.createBuffer({
      // Room for four 4x4 matrices. Each frame writes two of them
      // (world-view-projection, then world-view); the rest is unused padding.
      size: MAT4X4_BYTES * 4,
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
    });
    // CPU-side staging for the two matrices, reused every frame.
    const matrices = new Float32Array(32);

    const mapInfoBuffer = device.createBuffer({
      // Buffer holding mapping type, light uniforms, and depth uniforms
      size: Float32Array.BYTES_PER_ELEMENT * 8,
      usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
    });
    const mapInfoArray = new ArrayBuffer(mapInfoBuffer.size);
    const mapInfoView = new DataView(mapInfoArray, 0, mapInfoArray.byteLength);

    // Fetches an image, decodes it and uploads it into a GPUTexture.
    const loadTexture = async (url: string): Promise<GPUTexture> => {
      const response = await fetch(url);
      if (!response.ok) {
        throw new Error(
          `Failed to fetch texture "${url}": ${response.status} ${response.statusText}`
        );
      }
      const imageBitmap = await createImageBitmap(await response.blob());
      return createTextureFromImage(device, imageBitmap);
    };

    // The images are independent of each other, so load them in parallel.
    const [
      woodAlbedoTexture,
      spiralNormalTexture,
      spiralHeightTexture,
      toyboxNormalTexture,
      toyboxHeightTexture,
      brickwallAlbedoTexture,
      brickwallNormalTexture,
      brickwallHeightTexture,
    ] = await Promise.all([
      loadTexture("../src/assets/imgs/wood_albedo.png"),
      loadTexture("../src/assets/imgs/spiral_normal.png"),
      loadTexture("../src/assets/imgs/spiral_height.png"),
      loadTexture("../src/assets/imgs/toybox_normal.png"),
      loadTexture("../src/assets/imgs/toybox_height.png"),
      loadTexture("../src/assets/imgs/brickwall_albedo.png"),
      loadTexture("../src/assets/imgs/brickwall_normal.png"),
      loadTexture("../src/assets/imgs/brickwall_height.png"),
    ]);

    // Create a sampler with linear filtering for smooth interpolation.
    const sampler = device.createSampler({
      magFilter: "linear",
      minFilter: "linear",
    });

    const box = createMeshRenderable(
      device,
      createBoxMeshWithTangents(1.0, 1.0, 1.0)
    );

    // Uniform bindGroups and bindGroupLayout
    const frameBGDescriptor = createBindGroupDescriptor(
      [0, 1],
      [
        GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT,
        GPUShaderStage.FRAGMENT | GPUShaderStage.VERTEX,
      ],
      ["buffer", "buffer"],
      [{ type: "uniform" }, { type: "uniform" }],
      [[{ buffer: spaceTransformsBuffer }, { buffer: mapInfoBuffer }]],
      "Frame",
      device
    );

    // Texture bindGroups and bindGroupLayout
    const surfaceBGDescriptor = createBindGroupDescriptor(
      [0, 1, 2, 3],
      [GPUShaderStage.FRAGMENT],
      ["sampler", "texture", "texture", "texture"],
      [
        { type: "filtering" },
        { sampleType: "float" },
        { sampleType: "float" },
        { sampleType: "float" },
      ],
      // One bind group per texture set, all sharing the layout defined above.
      // Order matters: it must match `surfaceBindGroupIndices` below.
      [
        [
          sampler,
          woodAlbedoTexture.createView(),
          spiralNormalTexture.createView(),
          spiralHeightTexture.createView(),
        ],
        [
          sampler,
          woodAlbedoTexture.createView(),
          toyboxNormalTexture.createView(),
          toyboxHeightTexture.createView(),
        ],
        [
          sampler,
          brickwallAlbedoTexture.createView(),
          brickwallNormalTexture.createView(),
          brickwallHeightTexture.createView(),
        ],
      ],
      "Surface",
      device
    );

    const aspect = canvas.width / canvas.height;
    const projectionMatrix = mat4.perspective(
      (2 * Math.PI) / 5,
      aspect,
      0.1,
      10.0
    );

    // View matrix for a camera at the GUI-controlled position looking at the origin.
    const getViewMatrix = () =>
      mat4.lookAt(
        [
          this.settings.cameraPosX,
          this.settings.cameraPosY,
          this.settings.cameraPosZ,
        ],
        [0, 0, 0],
        [0, 1, 0]
      );

    // Model matrix: rotation about Y driven by wall-clock time.
    const getModelMatrix = () => {
      const modelMatrix = mat4.identity();
      const now = Date.now() / 1000;
      mat4.rotateY(modelMatrix, now * -0.5, modelMatrix);
      return modelMatrix;
    };

    // Translate the GUI's mode name to the numeric mode uploaded to the shader.
    const getMode = (): number => {
      switch (this.settings["Bump Mode"]) {
        case "Albedo Texture":
          return 0;
        case "Normal Texture":
          return 1;
        case "Depth Texture":
          return 2;
        case "Normal Map":
          return 3;
        case "Parallax Scale":
          return 4;
        case "Steep Parallax":
          return 5;
      }
    };

    const texturedCubePipeline = create3DRenderPipeline(
      device,
      "NormalMappingRender",
      [frameBGDescriptor.bindGroupLayout, surfaceBGDescriptor.bindGroupLayout],
      normalMapWGSL,
      // Position,   normal       uv           tangent      bitangent
      ["float32x3", "float32x3", "float32x2", "float32x3", "float32x3"],
      normalMapWGSL,
      presentationFormat,
      true
    );

    // Texture names mapped to indices into `surfaceBGDescriptor.bindGroups`.
    const surfaceBindGroupIndices = new Map<string, number>([
      ["Spiral", 0],
      ["Toybox", 1],
      ["BrickWall", 2],
    ]);
    let currentSurfaceBindGroup = 0;
    const onChangeTexture = () => {
      const index = surfaceBindGroupIndices.get(this.settings.Texture);
      // Settings arrive untyped; an unknown name keeps the previous texture
      // set instead of leaving the draw call without a valid bind group.
      if (index !== undefined) {
        currentSurfaceBindGroup = index;
      }
    };

    this.onChangeTexture = onChangeTexture;
    // `onUpdate()` may have replaced the settings while the awaits above were
    // pending, when this hook was not installed yet; apply that choice now.
    onChangeTexture();

    const frame = () => {
      const settings = this.settings;

      // Update spaceTransformsBuffer
      const viewMatrix = getViewMatrix();
      const worldViewMatrix = mat4.mul(viewMatrix, getModelMatrix());
      const worldViewProjMatrix = mat4.mul(projectionMatrix, worldViewMatrix);
      matrices.set(worldViewProjMatrix, 0);
      matrices.set(worldViewMatrix, 16);
      device.queue.writeBuffer(
        spaceTransformsBuffer,
        0,
        matrices.buffer,
        matrices.byteOffset,
        matrices.byteLength
      );

      // Update mapInfoBuffer
      const lightPosWS = vec3.create(
        settings.lightPosX,
        settings.lightPosY,
        settings.lightPosZ
      );
      const lightPosVS = vec3.transformMat4(lightPosWS, viewMatrix);
      const mode = getMode();

      // struct MapInfo {
      //   lightPosVS: vec3f,
      //   mode: u32,
      //   lightIntensity: f32,
      //   depthScale: f32,
      //   depthLayers: f32,
      // }
      mapInfoView.setFloat32(0, lightPosVS[0], true);
      mapInfoView.setFloat32(4, lightPosVS[1], true);
      mapInfoView.setFloat32(8, lightPosVS[2], true);
      mapInfoView.setUint32(12, mode, true);
      mapInfoView.setFloat32(16, settings.lightIntensity, true);
      mapInfoView.setFloat32(20, settings.depthScale, true);
      mapInfoView.setFloat32(24, settings.depthLayers, true);
      device.queue.writeBuffer(mapInfoBuffer, 0, mapInfoArray);

      // The canvas hands out a new texture every frame, so the color
      // attachment has to be rebuilt each time.
      const renderPassDescriptor: GPURenderPassDescriptor = {
        colorAttachments: [
          {
            view: context.getCurrentTexture().createView(),
            clearValue: [0, 0, 0, 1],
            loadOp: "clear",
            storeOp: "store",
          },
        ],
        depthStencilAttachment: {
          view: depthView,
          depthClearValue: 1.0,
          depthLoadOp: "clear",
          depthStoreOp: "store",
        },
      };

      const commandEncoder = device.createCommandEncoder();
      const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor);
      // Draw textured Cube
      passEncoder.setPipeline(texturedCubePipeline);
      passEncoder.setBindGroup(0, frameBGDescriptor.bindGroups[0]);
      passEncoder.setBindGroup(
        1,
        surfaceBGDescriptor.bindGroups[currentSurfaceBindGroup]
      );
      passEncoder.setVertexBuffer(0, box.vertexBuffer);
      passEncoder.setIndexBuffer(box.indexBuffer, "uint16");
      passEncoder.drawIndexed(box.indexCount);
      passEncoder.end();
      device.queue.submit([commandEncoder.finish()]);

      requestAnimationFrame(frame);
    };
    requestAnimationFrame(frame);
  }

  /**
   * Replaces the current settings with `data` and, if `run()` has already
   * installed the hook, re-applies the texture selection.
   *
   * @param data New settings object; stored as-is without validation.
   */
  onUpdate(data: any): void {
    this.settings = data;
    if (this.onChangeTexture) {
      this.onChangeTexture();
    }
  }
}
